import torch
import copy
import numpy as np
import time 
import scipy.stats as stats 
import math
import safetensors

def combination(n, k):
    """Return the binomial coefficient C(n, k) as a float.

    Args:
        n: Total number of items.
        k: Number of items chosen.

    Returns:
        ``n! / (k! * (n - k)!)`` converted to ``float``; ``0.0`` when ``k``
        is negative or larger than ``n``.
    """
    if k < 0 or k > n:
        return 0.0
    # The division is exact, so integer arithmetic is used until the very end
    # and only the final quotient is converted to float.
    numerator = math.factorial(n)
    denominator = math.factorial(k) * math.factorial(n - k)
    return float(numerator // denominator)

def cal_s_norm_normal(K, mu, sigma, e_min, e_max, b):
    """Accumulate a probability-weighted power-of-two scale for Normal(mu, sigma).

    For every integer exponent ``e`` in ``[e_min, e_max]`` the quantity

        P_e = sum_{i=1..K} C(K, i) * p1**i * p2**(K - i)

    is evaluated, where ``p1 = P(2**(e-1) < |X| < 2**e)`` and
    ``p2 = P(|X| < 2**(e-1))`` for ``X ~ Normal(mu, sigma)``.  By the binomial
    theorem this equals ``P(|X| < 2**e)**K - P(|X| < 2**(e-1))**K`` (the
    probability that the largest magnitude among ``K`` independent draws
    falls in the band ``[2**(e-1), 2**e)``), so the closed form is used
    instead of the explicit O(K) sum.  That avoids the factorial work and the
    float overflow of ``C(K, i)`` for large ``K``.

    Args:
        K: Number of independent samples (group size).  ``K <= 0`` yields a
            probability of 0 for every exponent, as an empty sum would.
        mu: Mean of the distribution; a scalar or a ``torch.Tensor``.
        sigma: Standard deviation of the distribution.
        e_min: Smallest exponent considered (inclusive).
        e_max: Largest exponent considered (inclusive).
        b: Only enters through the weight ``2**(e - b + 1)``; presumably the
            quantisation bit-width -- confirm against callers.

    Returns:
        A tuple ``(total, e_pro)`` where ``e_pro[e]`` is ``P_e`` and ``total``
        is ``sum_e 2**(e - b + 1) * P_e``.
    """
    def _central_mass(bound):
        # Probability mass of Normal(mu, sigma) between -bound and +bound.
        return stats.norm.cdf(bound, mu, sigma) - stats.norm.cdf(-bound, mu, sigma)

    if isinstance(mu, torch.Tensor):
        total = torch.zeros(size=mu.shape)
    else:
        total = 0.
    e_pro = {}
    # The outer mass for exponent e is the inner mass for exponent e + 1, so
    # each CDF pair is evaluated once and carried over to the next iteration.
    inner_mass = _central_mass(2**(e_min - 1))
    for e in range(e_min, e_max + 1):
        outer_mass = _central_mass(2**e)
        if K > 0:
            band_prob = outer_mass**K - inner_mass**K
        else:
            band_prob = 0.
        e_pro[e] = band_prob
        total += (2**(e - b + 1)) * band_prob
        inner_mass = outer_mass
    return total, e_pro


def F(x,mu,lamda):
    """Evaluate the Laplace cumulative distribution function at ``x``.

    Args:
        x: Point at which the CDF is evaluated.
        mu: Location (centre) of the distribution.
        lamda: Scale parameter of the distribution.

    Returns:
        ``0.5 * exp((x - mu) / lamda)`` when ``x <= mu``, otherwise
        ``1 - 0.5 * exp(-(x - mu) / lamda)``.
    """
    offset = x - mu
    # Lower branch covers x <= mu, upper branch the mirrored tail above mu.
    return (
        0.5 * math.exp(offset / lamda)
        if x <= mu
        else 1 - 0.5 * math.exp(-offset / lamda)
    )

def cal_s_norm_laplace(K, mu, sigma, e_min, e_max, b):
    """Accumulate a probability-weighted power-of-two scale for a Laplace law.

    ``sigma`` is converted to the Laplace scale ``lamda = sigma / sqrt(2)``
    and the CDF ``F`` is used.  For every integer exponent ``e`` in
    ``[e_min, e_max]`` the quantity

        P_e = sum_{i=1..K} C(K, i) * p1**i * p2**(K - i)

    is evaluated, where ``p1 = P(2**(e-1) < |X| < 2**e)`` and
    ``p2 = P(|X| < 2**(e-1))``.  By the binomial theorem this equals
    ``P(|X| < 2**e)**K - P(|X| < 2**(e-1))**K`` (the probability that the
    largest magnitude among ``K`` independent draws falls in the band
    ``[2**(e-1), 2**e)``), so the closed form is used instead of the explicit
    O(K) sum.  That avoids the factorial work and the float overflow of
    ``C(K, i)`` for large ``K``.

    Args:
        K: Number of independent samples (group size).  ``K <= 0`` yields a
            probability of 0 for every exponent, as an empty sum would.
        mu: Location of the distribution; a scalar or a ``torch.Tensor``.
        sigma: Spread parameter; divided by ``sqrt(2)`` to obtain ``lamda``.
        e_min: Smallest exponent considered (inclusive).
        e_max: Largest exponent considered (inclusive).
        b: Only enters through the weight ``2**(e - b + 1)``; presumably the
            quantisation bit-width -- confirm against callers.

    Returns:
        A tuple ``(total, e_pro)`` where ``e_pro[e]`` is ``P_e`` and ``total``
        is ``sum_e 2**(e - b + 1) * P_e``.
    """
    lamda = sigma/(2**(1/2))

    def _central_mass(bound):
        # Probability mass of the Laplace law between -bound and +bound.
        return F(bound, mu, lamda) - F(-bound, mu, lamda)

    if isinstance(mu, torch.Tensor):
        total = torch.zeros(size=mu.shape)
    else:
        total = 0.
    e_pro = {}
    # The outer mass for exponent e is the inner mass for exponent e + 1, so
    # each CDF pair is evaluated once and carried over to the next iteration.
    inner_mass = _central_mass(2**(e_min - 1))
    for e in range(e_min, e_max + 1):
        outer_mass = _central_mass(2**e)
        if K > 0:
            band_prob = outer_mass**K - inner_mass**K
        else:
            band_prob = 0.
        e_pro[e] = band_prob
        total += (2**(e - b + 1)) * band_prob
        inner_mass = outer_mass
    return total, e_pro


def exp_pro(e_pro_A, e_pro_W):
    """Combine two exponent distributions into the distribution of their sum.

    For every pair ``(a, w)`` of keys the product
    ``e_pro_A[a] * e_pro_W[w]`` is added to the entry for ``a + w``.  Each
    accumulated entry is then reduced with ``.mean()`` when the value offers
    that method (e.g. tensors or NumPy values); plain Python numbers are
    returned unchanged instead of raising ``AttributeError``.

    Args:
        e_pro_A: Mapping from integer exponent to probability.
        e_pro_W: Mapping from integer exponent to probability.

    Returns:
        A dict keyed by exponent sum.  Keys ``-20`` .. ``20`` are always
        present (``0.`` when no pair hits them); sums outside that window are
        added as extra keys rather than raising ``KeyError``.
    """
    sum_dict = {i: 0. for i in range(-20, 21)}
    for key_A, prob_A in e_pro_A.items():
        for key_W, prob_W in e_pro_W.items():
            key = key_A + key_W
            sum_dict[key] = sum_dict.get(key, 0.) + prob_A * prob_W
    # Entries that were never touched are still the float 0., which has no
    # .mean(); only array-like values are reduced.
    return {
        key: value.mean() if hasattr(value, "mean") else value
        for key, value in sum_dict.items()
    }

# Value forwarded as ``b`` to cal_s_norm_laplace.
bit = 8

# Group sizes to sweep: 2..30 in steps of 2, then the step doubles for every
# following range, ending at 1024.
list_k = []
for start, stop, step in (
    (2, 32, 2),
    (32, 64, 4),
    (64, 128, 8),
    (128, 256, 16),
    (256, 512, 32),
    (512, 1088, 64),
):
    list_k.extend(range(start, stop, step))
# Shorter sweep that can be swapped in: list_k = [8,16,32,64,128]

# Largest predicted value seen so far for each group size.
result_dict = {k: 0 for k in list_k}

# cal_s_norm_laplace(k, 0, 1, -50, 50, bit) does not depend on the weight
# tensor, so it is computed once per group size (lazily, on first use) and
# reused for every tensor.
activation_stats = {}

with safetensors.safe_open("./BERT-SQuAD/outputs/bert-based/model.safetensors", framework="pt", device='cpu') as f:
    for tensor_name in f.keys():
        # Only tensors whose name contains 'layer' are analysed.
        if 'layer' not in tensor_name:
            continue
        weight_ = f.get_tensor(tensor_name)
        # 1-D tensors are skipped.
        if len(weight_.shape) == 1:
            continue
        weight_mean = weight_.mean()
        weight_var = weight_.var()
        for k in list_k:
            if k not in activation_stats:
                activation_stats[k] = cal_s_norm_laplace(k, 0, 1, -50, 50, bit)
            predict_scale_A, e_pro_A = activation_stats[k]
            # NOTE(review): sqrt(0.5 * var) is passed as sigma and
            # cal_s_norm_laplace divides by sqrt(2) again, so the effective
            # Laplace scale is sqrt(var) / 2 -- confirm this is intended.
            predict_scale_W, e_pro_W = cal_s_norm_laplace(k, weight_mean, (0.5*weight_var)**(1/2), -50, 50, bit)
            predict_sigma = (weight_var / (predict_scale_W**2) + weight_mean**2/(predict_scale_W**2))**(1/2)/predict_scale_A*(2**(1/2))
            predict_scale_3sigma = (3*k*predict_sigma)
            # ceil(log2(3 * k * sigma)) + 1; presumably a bit count for the
            # accumulated value -- confirm.
            predict_result = torch.log2(predict_scale_3sigma).ceil() + 1
            # Keep the maximum over all analysed tensors.
            if predict_result > result_dict[k]:
                result_dict[k] = int(predict_result)
        # Progress output after each analysed tensor.
        print(result_dict)
print(result_dict)